import copy as cp

from qubiter.adv_applications.MeanHamil import *
from qubiter.device_specific.Qubiter_to_RigettiPyQuil import *
from qubiter.device_specific.RigettiTools import *
import qubiter.utilities_gen as utg

from openfermion.ops import QubitOperator

from pyquil.quil import Program, Pragma
from pyquil.gates import *
from pyquil.api import WavefunctionSimulator


class MeanHamil_rigetti(MeanHamil):
    """
    This class is a child of MeanHamil.

    This class uses either Rigetti's real hardware or virtual simulator to
    calculate mean values. `qc` returned by Rigetti's get_qc() method is
    passed in as an input to the constructor of this class. If num_samples
    !=0, the class uses qc.run() to calculate mean values. If num_samples=0,
    the class ignores the `qc` input and uses PyQuil's WavefunctionSimulator
    to calculate mean values exactly.


    Attributes
    ----------
    do_resets : bool
        if True and num_samples != 0, a RESET instruction is added to the
        prelude of every compiled program
    pg : Program
        object of PyQuil class `Program`. When num_samples=0 it is an empty
        Program. When num_samples != 0 it is, after construction, the
        program built for the last term of hamil (or only the
        hamil-independent part if hamil has no terms).
    qc : QuantumComputer
        returned by PyQuil method get_qc()
    term_to_exec : dict[tuple, object]
        maps a term to an executable. QubitOperator from OpenFermion has
        attribute `terms` which is a dict from a term to a coefficient. An
        executable is the output of PyQuil's compile() method. Empty dict
        when num_samples=0.
    translation_line_list : list[str]
        a list of lines of PyQuil code generated by the translator. The
        lines all start with "pg +="
    translator : Qubiter_to_RigettiPyQuil

    """

    def __init__(self, qc, file_prefix, num_qbits, hamil,
                 all_var_nums, fun_name_to_fun,
                 do_resets=True, **kwargs):
        """
        Constructor

        Do in constructor as much hamil indep stuff as possible so don't
        have to redo it with every call to cost fun. Also,
        when self.num_samples !=0, we store a dict called term_to_exec
        mapping each term of the hamiltonian hamil to an executable (output
        of Rigetti compile() function). When num_samples=0,
        term_to_exec={}

        Parameters
        ----------
        qc : QuantumComputer
        file_prefix : str
        num_qbits : int
        hamil : QubitOperator
        all_var_nums : list[int]
        fun_name_to_fun : dict[str, function]
        do_resets : bool
        kwargs : dict
            key-word args passed through to the MeanHamil constructor

        Returns
        -------

        """

        MeanHamil.__init__(self, file_prefix, num_qbits, hamil,
                           all_var_nums, fun_name_to_fun, **kwargs)
        self.qc = qc
        self.do_resets = do_resets
        # documented as {} when num_samples=0, so always define it
        self.term_to_exec = {}

        # this creates a file with all PyQuil gates that
        # are independent of hamil. Gates may contain free parameters
        self.translator = Qubiter_to_RigettiPyQuil(
            self.file_prefix, self.num_qbits,
            aqasm_name='RigPyQuil', prelude_str='', ending_str='')
        with open(utg.preface(self.translator.aqasm_path), 'r') as fi:
            self.translation_line_list = fi.readlines()

        pg = Program()
        self.pg = pg
        if not self.num_samples:
            # exact mode: programs are built in get_mean_val()
            return

        # pg prelude
        pg += Pragma('INITIAL_REWIRING', ['"PARTIAL"'])
        if self.do_resets:
            pg += RESET()
        ro = pg.declare('ro', 'BIT', self.num_qbits)

        # Declare one REAL memory region per placeholder variable. The
        # translated lines refer to these by name, so they are put in an
        # explicit namespace for exec(); exec() cannot reliably create
        # function locals (and cannot at all under PEP 667 / Python 3.13).
        namespace = {**globals(), 'self': self, 'pg': pg, 'ro': ro}
        for var_num in self.all_var_nums:
            vname = self.translator.vprefix + str(var_num)
            namespace[vname] = pg.declare(vname, memory_type="REAL")

        # add to pg the operations that are independent of hamil
        base_pg = self._exec_translation_lines(namespace)
        self.pg = base_pg

        # hamil loop to store executables for each term in hamil
        for term in self.hamil.terms:
            # start every term from a fresh copy of the hamil-independent
            # program so that base_pg itself is never mutated
            pg = Program(base_pg)
            self.pg = pg

            # add xy measurements coda to pg
            RigettiTools.add_xy_meas_coda_to_program(
                pg, self._get_bit_pos_to_xy_str(term))

            # request measurements
            for i in range(self.num_qbits):
                pg += MEASURE(i, ro[i])

            pg.wrap_in_numshots_loop(shots=self.num_samples)

            self.term_to_exec[term] = self.qc.compile(pg)

    @staticmethod
    def _get_bit_pos_to_xy_str(term):
        """
        Returns a dict mapping bit position to action for every
        (bit, action) pair of the OpenFermion `term` whose action is not
        'Z'. This is the input expected by
        RigettiTools.add_xy_meas_coda_to_program().

        Parameters
        ----------
        term : tuple[tuple[int, str]]

        Returns
        -------
        dict[int, str]

        """
        return {bit: action for bit, action in term if action != 'Z'}

    def _exec_translation_lines(self, namespace):
        """
        Executes every non-empty line of self.translation_line_list (code
        generated by self.translator, not external input) inside the dict
        `namespace`, which must contain 'pg' plus every name the lines
        reference (gates, placeholder variables). Returns namespace['pg']
        afterwards, in case a line rebound it.

        Parameters
        ----------
        namespace : dict[str, object]

        Returns
        -------
        Program

        """
        for line in self.translation_line_list:
            line = line.strip('\n')
            if line:
                exec(line, namespace)
        return namespace['pg']

    def _get_wavefunction_base_program(self, var_name_to_rads):
        """
        Returns the hamil-independent Program used when num_samples=0: an
        identity gate on every qubit followed by the translated gates, with
        each placeholder variable bound to its numeric value.

        Parameters
        ----------
        var_name_to_rads : dict[str, list[float]]
            maps placeholder variable name to a one-item list holding its
            value in radians

        Returns
        -------
        Program

        """
        pg = Program()
        # Don't know how to declare the number of qubits directly, so touch
        # every qubit with an identity gate. That way the wavefunction has
        # 2**num_qbits amplitudes, as the reshape in get_mean_val expects.
        for k in range(self.num_qbits):
            pg += I(k)
        namespace = {**globals(), 'self': self, 'pg': pg}
        for vname, rads_list in var_name_to_rads.items():
            namespace[vname] = float(rads_list[0])
        return self._exec_translation_lines(namespace)

    def get_mean_val(self, var_num_to_rads):
        """
        This method returns the empirically determined Hamiltonian mean
        value. It takes as input the values of placeholder variables. It
        passes those values into the Rigetti method run() when num_samples
        !=0. When num_samples=0, WavefunctionSimulator is used to calculate
        the output mean value exactly.

        The result is the sum over the terms of hamil of
        coef * st_vec.get_mean_value_of_real_diag_mat(self.get_real_vec(
        term)), where st_vec is an empirical (num_samples != 0) or exact
        (num_samples=0) StateVec for that term.

        Parameters
        ----------
        var_num_to_rads : dict[int, float]

        Returns
        -------
        float

        """
        # loop-invariant: same variable bindings for every term
        vprefix = self.translator.vprefix
        var_name_to_rads = {vprefix + str(vnum): [rads]
                            for vnum, rads in var_num_to_rads.items()}

        if self.num_samples:
            sim = None
            base_pg = None
        else:
            sim = WavefunctionSimulator()
            base_pg = self._get_wavefunction_base_program(var_name_to_rads)

        # hamil loop
        mean_val = 0
        for term, coef in self.hamil.terms.items():
            # we have checked before that coef is real
            coef = complex(coef).real

            if self.num_samples:
                # send and receive from cloud, get obs_vec
                bitstrings = self.qc.run(self.term_to_exec[term],
                                         memory_map=var_name_to_rads)
                obs_vec = RigettiTools.obs_vec_from_bitstrings(
                        bitstrings, self.num_qbits, bs_is_array=True)

                # go from obs_vec to effective state vec
                counts_dict = StateVec.get_counts_from_obs_vec(self.num_qbits,
                                                               obs_vec)
                emp_pd = StateVec.get_empirical_pd_from_counts(self.num_qbits,
                                                               counts_dict)
                effective_st_vec = StateVec.get_emp_state_vec_from_emp_pd(
                        self.num_qbits, emp_pd)
            else:  # num_samples = 0
                # fresh copy so the coda of one term doesn't leak into
                # the next
                pg = Program(base_pg)
                RigettiTools.add_xy_meas_coda_to_program(
                    pg, self._get_bit_pos_to_xy_str(term))
                st_vec_arr = sim.wavefunction(pg).amplitudes
                st_vec_arr = st_vec_arr.reshape([2]*self.num_qbits)
                # reverse axis (qubit) order; presumably PyQuil and
                # StateVec order qubits oppositely -- TODO confirm
                perm = list(reversed(range(self.num_qbits)))
                st_vec_arr = np.transpose(st_vec_arr, perm)
                effective_st_vec = StateVec(self.num_qbits, st_vec_arr)

            # add contribution to mean
            real_arr = self.get_real_vec(term)
            mean_val += coef*effective_st_vec.\
                get_mean_value_of_real_diag_mat(real_arr)

        return mean_val


if __name__ == "__main__":
    def main():
        """Script entry point; this module currently has no demo to run."""
        pass

    main()
